{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "getting_started_pytorch.ipynb",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "BYzpwF2FLeCH"
      },
      "source": [
        "# Approximate Inference in Bayesian Deep Learning: Getting Started in PyTorch\n",
        "\n",
        "This colab walks you through downloading the data, running your method, and generating a submission for our NeurIPS 2021 competition. It uses the PyTorch framework; for JAX, see [this notebook](https://colab.research.google.com/drive/1SJ6waN8DOfby6qW9WWgJ0VLyaNzauhYD?usp=sharing).\n",
        "\n",
        "Useful references:\n",
        "- [Competition website](https://izmailovpavel.github.io/neurips_bdl_competition/)\n",
        "- [Efficient implementation of several baselines in JAX](https://github.com/google-research/google-research/tree/master/bnn_hmc)\n",
        "- [Submission platform](https://competitions.codalab.org/competitions/33647)\n",
        "\n",
        "\n",
        "## Setting up colab\n",
        "\n",
        "Colab provides an easy-to-use environment for working on the competition with access to free computational resources. However, you should also be able to run this notebook locally after installing the required dependencies. If you use colab, please select a `GPU` runtime type.\n",
        "\n",
        "## Preparing the data\n"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "MiSuC1jmhmDp",
        "outputId": "af2f2672-80a2-432c-94cf-e810fc45296b"
      },
      "source": [
        "!git clone https://github.com/izmailovpavel/neurips_bdl_starter_kit\n",
        "\n",
        "import sys\n",
        "import math\n",
        "import matplotlib\n",
        "import numpy as np\n",
        "import copy\n",
        "\n",
        "import torch \n",
        "import torch.nn as nn\n",
        "import torch.nn.functional as F\n",
        "from torch.utils.data import DataLoader, TensorDataset\n",
        "import torch.optim as optim\n",
        "from matplotlib import pyplot as plt\n",
        "\n",
        "sys.path.append(\"neurips_bdl_starter_kit\")\n",
        "import pytorch_models as p_models"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Cloning into 'neurips_bdl_starter_kit'...\n",
            "remote: Enumerating objects: 32, done.\u001b[K\n",
            "remote: Counting objects: 100% (32/32), done.\u001b[K\n",
            "remote: Compressing objects: 100% (20/20), done.\u001b[K\n",
            "remote: Total 32 (delta 13), reused 28 (delta 9), pack-reused 0\u001b[K\n",
            "Unpacking objects: 100% (32/32), done.\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "mrcvRuuJRGhZ"
      },
      "source": [
        "We provide the datasets used in this competition in `.csv` format in a public Google Cloud Storage bucket. The `gsutil` cell below downloads them into the working directory.\n",
        "\n",
        "You can also download the data to your computer by clicking these links:\n",
        "- [CIFAR-10 train features](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_train_x.csv)\n",
        "- [CIFAR-10 train labels](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_train_y.csv)\n",
        "- [CIFAR-10 test features](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_test_x.csv)\n",
        "- [CIFAR-10 test labels](https://storage.googleapis.com/neurips2021_bdl_competition/cifar10_test_y.csv)"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "da-9ecSDnKO9",
        "outputId": "115a104c-cd9e-4717-86ee-d457ddb088aa"
      },
      "source": [
        "!gsutil -m cp -r gs://neurips2021_bdl_competition/cifar10_*.csv ."
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Copying gs://neurips2021_bdl_competition/cifar10_test_y.csv...\n",
            "/ [0/4 files][    0.0 B/  4.4 GiB]   0% Done                                    \rCopying gs://neurips2021_bdl_competition/cifar10_test_x.csv...\n",
            "/ [0/4 files][    0.0 B/  4.4 GiB]   0% Done                                    \rCopying gs://neurips2021_bdl_competition/cifar10_train_x.csv...\n",
            "Copying gs://neurips2021_bdl_competition/cifar10_train_y.csv...\n",
            "/ [4/4 files][  4.4 GiB/  4.4 GiB] 100% Done  70.4 MiB/s ETA 00:00:00           \n",
            "Operation completed over 4 objects/4.4 GiB.                                      \n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zDsZHMQzSGpZ"
      },
      "source": [
        "We can now read the data and convert it into NumPy arrays. This cell may take several minutes to run."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "4o8NLOGgnVSV"
      },
      "source": [
        "x_train = np.loadtxt(\"cifar10_train_x.csv\")\n",
        "y_train = np.loadtxt(\"cifar10_train_y.csv\")\n",
        "x_test = np.loadtxt(\"cifar10_test_x.csv\")\n",
        "y_test = np.loadtxt(\"cifar10_test_y.csv\")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "iQgkl3ZP1QzM"
      },
      "source": [
        "x_train = x_train.reshape((len(x_train), 3, 32, 32))\n",
        "x_test = x_test.reshape((len(x_test), 3, 32, 32))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 304
        },
        "id": "i3nFA3RRyuwd",
        "outputId": "72789385-aeb2-4386-a528-a7ed3aaac292"
      },
      "source": [
        "plt.imshow(x_train[0].reshape(32,32,3))\n",
        "plt.colorbar()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Clipping input data to the valid range for imshow with RGB data ([0..1] for floats or [0..255] for integers).\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "<matplotlib.colorbar.Colorbar at 0x7f6416853ad0>"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 5
        },
        {
          "output_type": "display_data",
          "data": {
            "image/png": "iVBORw0KGgoAAAANSUhEUgAAAS4AAAD8CAYAAADJwUnTAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADh0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uMy4yLjIsIGh0dHA6Ly9tYXRwbG90bGliLm9yZy+WH4yJAAAfAUlEQVR4nO3df5yWdZ3v8ddHGnVRw2hYIpR0kY04UWiTuEkWmkTiKu2Gqbsd2lRalGM/tHM4dkSzHkfNzKPnGC0qK/YQDfpBnKDEFPWwuxFIBIKooIawo0AqWbPmMPM5f1zX1D0z1/d7XzNzz33f1/B+Ph7Xg5nv5/rx5Z7hw3Vd31/m7oiIFMkhta6AiEhPKXGJSOEocYlI4ShxiUjhKHGJSOEocYlI4ShxiUi/MbOFZrbHzJ4IxM3MbjOz7Wa2ycxOynNeJS4R6U93A1Mj8Y8BY9JtFjA/z0mVuESk37j7Y8DLkV3OBe7xxM+Bo81sRLnzvqlSFczjcDM/IhBrjxz3RqDcIse0RmIHIrHYBxKqh0g9cffYP42yPjr5CP/Ny2259n180x+2AK+XFC1w9wU9uNxI4IWS73elZc2xg/qUuMxsKnArMAi4091viO1/BOF7xpbIcTsD5Q2RY2J/632RWGMv6iEykPzm5TZ+8cCoXPsOGvHM6+7e1M9V6qbXicvMBgG3A2eSZMl1Zrbc3bdWqnIiUn0OtEefgSpqN3BsyffHpGVRfXnHdTKw3d2fdfc3gPtJnldFpMAcp9Xbcm0VsBz4z2nr4inAfnePPiZC3x4Vs55NJ3bdycxmkbQWMLgPFxOR6qnUHZeZ3Qd8GGg0s13ANaRvedz928BK4CxgO8kbo3/Ic95+fzmfvqhbAPBWM82hI1LnHKetQtNdufsFZeIOXNbT8/YlcfXq2VRE6l879X2P0ZfEtQ4YY2bHkySs84ELYwfsB1YEYrEDb/3CVzLLdzSHH4W3Lft2+ISvh0OxlsoNgfJY14vY43GslXJZJCbSnxxoG6iJy90PmNkc4AGS7hAL3X1LxWomIjUzkO+4cPeVJC/XRGSAcKC1zqd0r2rPeRGpf44P3EdFERmgHNrqO28pcYlIZ0nP+fqmxCUiXRht0SkMas+qua6i9bIDqv/sD9mBSF8D+8BhwVisy8Pvf/K7YCzU/eJdn3lf5Izh4ePfmnBdMHbpxquCsbGRq108PPv/ohlDwse84+nYfBlhU3hzMLaB32aWxwa4H6xCP5r9vTxfX2eHePd7DvUlK4bl2vc/jfr3xws1yFpEBqakH1d933EpcYlIN+19u2nrd0pcItKJ7rhEpHAco63OZ3VX4hKRbvSoWAEf/Mh5meUT3z+54tdq2Rlu91q7IjT0ObsFrZyG/aFh2/GWw/GcHD5u9tWZ5aPGTwoe44sWB2O3rbgzGLv4G9cHYw2jRmcHWiNtupH5uz995eeCscUv/yh8YJ0YH4n9v6HZ5TsjS0yEBuh/Pm+FIhzjDR9UgTP1n0IkLhGpnqQDqh4VRaRg9HJeRArF3Whz3XGJSMG0645LRIokeTlf36mhvmsnIlVXhJfzhRhkLX3ng24KBy+cEo5NHxGO7Q93HVm9M7vBfvPmHcFjThof7jQwf354kfTFL9XHJLyxLizTOq0r09k3tvw4OzA4sq764LdnFjdNaWL9xvV9es47Yfxg//qyd+ba929P2KhB1iJSe+o5LyKF1K5WRREpkmSQtRKXiBSIY7RqyI+IFIk76oAqIkVj6oAqdaLtS+FYY6Q7xN9E5h7/t3Bs8dLsmTTu/MlXg8fMXTs9GGt+aXW4HnViWyQ2JDbb/qrsWMvo8EIBDaOyy701UomcnAF+x2VmzwOvAW3AgVr05xCRyjsYXs5Pdnct3iIyQDimiQRFpFgcaK3zsYp9vR90YJWZPW5ms7J2MLNZZrbezNb38VoiUhXJgrB5tlrpa1qd5O67ze
zPgQfNbJu7P1a6g7svABaAxiqKFIFT/z3n+1Q7d9+d/rkH+CFEJkMXkcIYsHdcZnYEcIi7v5Z+PQUIrykPHMpRvI1TMmM7iTV3926JeMlpW3jRDnhPMHLnDfcGYzsDi46MJtxev3bXfcFYBVr5+93gXsb278teQGRbQ3P4oJbsY1pe7/sn5W4VveMys6nArcAg4E53v6FLfBSwCDg63Weuu0en/OjLo+Jw4Idm1nGexe7+0z6cT0TqQPJyvjJDfsxsEHA7cCawC1hnZsvdfWvJbv8DWOLu881sHLASOC523l4nLnd/Fnhvb48XkXpV0TnnTwa2p/kCM7sfOBcoTVwOvDn9egjw7+VOWt9tniJSdcnL+dzvrxq79BhYkDbIdRgJvFDy/S5gYpdzXEvSO+G/AEcAHyl3USUuEemmBz3n91VgxMwFwN3ufrOZ/RXwHTN7t7u3hw5Q4hKRTircc343dJq3+pi0rNRFwFQAd/83MzscaAT2hE5a3501RKQm2jkk15bDOmCMmR1vZocC5wPLu+yzEzgDwMzeBRwO7I2dtKp3XO/E+XGgYbuFk4LHjSYwFD7SyLyCe4Kx7HkLOuoRFlo2Ym3kmFgsJtZ8Hqvjw725WGNkUYaIHcsXB2MTj8r+tFZFuraEl9GoH6MjsezlQRLj+Y9gbPUN8zLLNw8J/6THTsme0eONV6L/3nNxh9b2ytzTuPsBM5sDPEDS1WGhu28xs+uA9e6+HLgCuMPMvkDyiu3TXmYVHz0qikgnyaNi5R7G0j5ZK7uUzSv5eitwak/OqcQlIt3Usld8HkpcItJJD7tD1IQSl4h0UdlHxf6gxCUi3WjO+RKHvu+djFrfi7nDHwqMANgRboeaviK8IPr05UsjF9sfiWUPbI21J22OtCbF5iiPxa4e+t/CwSWBpeo/EmqZBe45Oxz7ypPBULgdGBpei32OPTc5Eov9RoXaS+f28nzhWeC7dwcvFWtxvK3tkczywS+Hj1lz/y8yyysxFXHSqqjlyUSkQDR1s4gUkh4VRaRQ1KooIoWkVkURKRR344ASl4gUjR4VK+GMt/esHGDWByMn/O99qk5PjO9l7O0W/sWZ9/KNwdi3LslusJ/9XLhB/qOWPWAXYBUPBmN+yzPBGK3Z86X72knhY9ZG5r7ftTkc45fBSKhTRgNvCx4zgxeDsUinkj54c6A81vkiexB7E0/0uTZ6xyUihaTEJSKFon5cIlJI6sclIoXiDgcqNJFgf1HiEpFu9KgoIoWid1x1qvVX4dj+yKQGjadVuCJd1zopEVl8PerS567KLj8+u7xPJp0QjjWFYrFuKpUX61AQ0j9dHqqkqa8rhSW8zhNX2QdZM1toZnvM7ImSsqFm9qCZPZP++Zb+raaIVFM7lmurlTxv4O4mXfOsxFzgIXcfAzxEfHojESkQ9+QdV56tVsomLnd/DOg6pdm5wKL060XA9ArXS0RqxmhrPyTXViu9fcc13N07XsO8CAwP7Whms4BZAKNGFfrtgchBo/DvuMpJF24MLt7o7gvcvcndm4YNG9bXy4lIP+sYq1joR8WAl8xsBED6557KVUlEasqT91x5tlrp7aPicmAmcEP6548qVqNKWR/+VC85+7xgbOdL4Y4IjUdlL70wedq04DEXXj4zGPvoByLdCQpgxftDi4fAtDnXZQf+d/Vm5pDeK/yQHzO7D/gw0Ghmu4BrSBLWEjO7CPg1EM4EIlIonr6cr2dlE5e7XxAInVHhuohInajlY2AeB2XPeRGJq/dWRSUuEekkefGuxCUiBaNB1iJSOHrH1Y/mf/FbwdiMSBeFiZPCy1Qs+v73whd8LVC+JtwtYPZ9lwRjMz80ORhb++g9wdj1wUg1lwGBaR+bHYztn784s3zx/wnPUjF7Y+Rfy3tzV6sTs9DPOryohNf7v9p+5hjtRW9VFJGDT72n7vpOqyJSfenL+TxbHmY21cyeMrPtZpY5k4yZnWdmW81si5ll366X0B2XiHRXoV
suMxsE3A6cCewC1pnZcnffWrLPGJK3HKe6+ytm9uflzqs7LhHppoJ3XCcD2939WXd/A7ifZFqsUpcAt7v7K8m1vezYZyUuEenEgfZ2y7WRDAVcX7LN6nK6kcALJd/vSstK/SXwl2b2L2b2czPrOnFpN3pUFJHOHMjfj2ufu/d1ovs3AWNIxkQfAzxmZuPd/dXYAXVvzT8/mVl+6S2XBY+ZfcWlkdgVwdil378mf8VSS3dFulBEzJh+Ybgeke4QG3p1td65PBacODYYGjL76szyeeeEX18smhD+xzLq+E8EY0ufWxaMwYFILJtZuB4HS1eJCv41dwPHlnx/DN2XidkFrHX3VuA5M3uaJJGtC51Uj4oi0p3n3MpbB4wxs+PN7FDgfJJpsUotI7nbwswaSR4dn42dtBB3XCJSTfm7OpTj7gfMbA7wADAIWOjuW8zsOmC9uy9PY1PMbCvQBnzJ3X8TO68Sl4h0V8EnYndfCazsUjav5GsHvphuuShxiUhnDt6uQdYiUjhKXH328c98vOcHDY7FjgiGJnJqMLaWf8ksb4h8jDt/kN0iCrCzOTy/fczSXh3VO7EWzM3Xfi4YG//HZTc72xc5361Dwy2Hy57bHIxNIbzs3ar4O14JqfPG00IkLhGpMiUuESmUnnVArQklLhHppt772SpxiUh3alUUkaIx3XGJSKHkH85TM4VIXOOPyZ43fPWup4LHHDF0SDA2e0J4cPOooY3B2NqXs8tbIwN5V61aE4zNnDYlGBvy9WCI/eFQxY2OxFZH/t4tQ7PLfz76H4PHTLx3fvhil8wMhhY/WnbCzIpZ893fB2OTPhnuZlMsVvcv58sOsjazhWa2x8yeKCm71sx2m9nGdDurf6spIlVVuUHW/SLP7BB3A1kTe93i7hPSbWVGXESKqj3nViNlHxXd/TEzO67/qyIidaEA/bj6Mh/XHDPblD5KviW0k5nN6pjWde/evX24nIhUi3m+rVZ6m7jmk7y3nQA0AzeHdnT3Be7e5O5Nw4YN6+XlRKSqBsA7rm7c/SV3b3P3duAOkpU8RESqolfdIcxshLt3TG3wcWLrmVfAwy9kz4dw+tumB49paAmfb/So0LLsMGPGjGCs9YbsjgjLXnskeMy2DTuCsYaZ4VkNXv2nrcHY6Z8dF4ytDkZ654r3/89gbFtz+GotMydllk/+2rzM8nIaCHdTWdGLeeV7a+3qtcHYpE+eXrV69LfCd0A1s/tI5oNuNLNdwDXAh81sAsnN4vPAZ/uxjiJSTU7xh/y4+wUZxXf1Q11EpF4U/Y5LRA4+hX9UFJGDkBKXiBSOEpeIFEmtO5fmUejE9fCLsaXXw7566jeCseZ9O4Oxu+/9Vmb50eeEuyfs2xGey2HbsnB3grGzw03rkyM/ttiMDb2xoiVc//GjwnNH7AwsVtLyyqvBY1qWhruOnPfoN4OxmPGdVn//k8280Kvz7d8Z/v0YUIreqigiBx/dcYlI8ShxiUih6B2XiBSSEpeIFI3VcJLAPPoyH5eISE0clHdc8/71S8HYvWP/VzA2ZMq7MstvPSc8g8KIIeGZKHZs2ByMtVy5IXxcFWdDaG1tCMZGDR4RjN15W/YCFhtWhP9et/3rj/JXLKdNnt194YQ3TQ4es6PtkWBs/Ojw33lA0aOiiBSKXs6LSCEpcYlI4ShxiUiRGGpVFJGiybnCT973YGY21cyeMrPtZjY3st/fmpmbWVO5cx6Ud1zhdjKYNDnc2hQ68OLZFwcPmX/zbcHYtjVrgrHm15uDsRXBSOXNe/prwVgr/xiMXX159oD0DWvDrYrQu1bFJeff1ONjbv3G7GDs7C88EozNmDGxx9cqpAo9KprZIOB24ExgF7DOzJa7+9Yu+x0FfA4IT+pfQndcItJd5ZYnOxnY7u7PuvsbwP3AuRn7fRW4EXg9z0mVuESkmx48KjZ2LPicbrO6nGokdJpDaFda9qdrmZ0EHOvuuR8mDspHRREpI/+j4j53L/tOKs
TMDgG+CXy6J8cpcYlIZ17RVsXd0Gk2x2PSsg5HAe8GHjEzgLcBy83sHHdfHzqpEpeIdFe5flzrgDFmdjxJwjofuPCPl3HfD39a7dfMHgGujCUt0DsuEclQqe4Q7n4AmAM8ADwJLHH3LWZ2nZmd09v65VnJ+ljgHmA4SR5e4O63mtlQ4LvAcSSrWZ/n7q/0tiLV9IZXtlvw4LHDgrEN28LN/82vh+cvbyY813vMhZyZWb6KB4PH7OvVleDi668Pxkb9zdGZ5WN/cFLwmEuWhwe/x0yaNr3Hx0z7/HnB2JQvfDJysey/14BTwX8i7r4SWNmlbF5g3w/nOWeeO64DwBXuPg44BbjMzMYBc4GH3H0M8FD6vYgUXd6uEDUcFlQ2cbl7s7tvSL9+jeR2byRJX4xF6W6LgJ7/tycidceobM/5/tCjl/NmdhxwIknv1uHu3tG9+0WSR0kRGQDqfVqb3C/nzexI4PvA5939t6Uxdw/eOJrZrI7OaXv37u1TZUWkSor+qAhgZg0kSeted/9BWvySmY1I4yOAPVnHuvsCd29y96Zhw8IvsUWkjhQ9cVnSK+wu4El3L11OeDkwM/16Jr0dISsi9aXCs0P0hzzvuE4FPgVsNrONadlVwA3AEjO7CPg1EG5fHuBWL3s4GJs2I9xmMXnU2GDs9C+EP84G3hyM3btxVXZgdPAQmje8GowNHhxu/h/Si4Eeg6dVvjvBvBvuDMbu+Psbeny+B372f8PBg6XnY52/4yqbuNx9DUlDQ5YzKlsdEakH9T6RoIb8iEg39d6qqMQlIp3V+MV7HkpcItKdEpeIFElHz/l6psQlIt1Ye31nLiWuHlhx048zyxctzl5uHmDJL8OxmMuXXRGMXfpoeBaFfYEFPRqPDF9rxGlVnPHgsMqf8s4tNwZjd9Dz7hCccXYfajMA6B2XiBSRHhVFpHiUuESkaHTHJSLFo8QlIoVS2VV++oUSl4h0on5cA8zZ//WvM8s3/eSZil9r9nVXBmOXfijcHWLzjuyZHiaPO0gWeehi6YLHM8tnzHpflWtSMBVeUKbSlLhEpBvdcYlIsagDqogUkV7Oi0jhKHGJSLE4ejlfNBt++mSPjxk/9YTKV+S03h22ZumKzPLJf/13fahM7V3+me8GY7ctDP/dVqzK/jzUqhinl/MiUjxKXCJSJOqAKiLF466JBEWkgOo7bylxiUh3elQUkWJxoOiPimZ2LHAPMJzkr7TA3W81s2uBS4C96a5XufvK/qpotexrCUzaDjy8vPKDqXvn2GBkxZq1meVXUx/dIbY99PtgbPzws4KxW+86LxibPmNaMLZqRXZ3CCmjvvMWh+TY5wBwhbuPA04BLjOzcWnsFnefkG6FT1oikjDPt+U6l9lUM3vKzLab2dyM+BfNbKuZbTKzh8zsHeXOWTZxuXuzu29Iv34NeBIYma/KIlJE1u65trLnMRsE3A58DBgHXFBy49Phl0CTu78H+B7w9XLnzXPHVVqJ44ATgY7nkTlpllxoZm/pyblEpE55D7byTga2u/uz7v4GcD9wbqfLua9295b0258Dx5Q7ae7EZWZHAt8HPu/uvwXmA6OBCUAzcHPguFlmtt7M1u/duzdrFxGpI0kHVM+1AY0d/77TbVaX040EXij5fhfxJ7aLgJ+Uq2OuVkUzayBJWve6+w8A3P2lkvgdQOZqqe6+AFgA0NTUVOev/EQEgPyzQ+xz96ZKXNLM/h5oAj5Ubt88rYoG3AU86e7fLCkf4e7N6bcfB57oXXVFpN5Y5WaH2E3nZvBj0rLO1zP7CPBl4EPu/odyJ81zx3Uq8Clgs5ltTMuuInnJNoHkSfd54LM5zlX35t92ZzD2w0d6sZx7P1jyk/AMFn93Xn10ewj54EdPD8b2HsjuylHO5KlHBGNrVzf26pwHtcrOgLoOGGNmx5MkrPOBC0t3MLMTgX8Cprr7njwnLZu43H0NyWNvV+r+IDIgVW6sorsfMLM5wAPAIGChu28xs+
uA9e6+HLgJOBJYmjzgsdPdz4mdVz3nRaS7Ck4kmPbxXNmlbF7J1x/p6TmVuESkMy0IKyKFpKmbRaRw6jtvKXGJSHfWXt/PikpcXWx+9LZItD66Q8yINP+f99qPqliTsNU3ZfZHprFtc+Uv9rtwaOnXL8ksn3vjjsrXY6BwetIBtSaUuESkE8Mr2QG1XyhxiUh3SlwiUjhKXCJSKHrHJSJFpFZFESkY16NiZ23Aq4HYvshxg7OLX4kcs21nMDTv7PDCC438R/icP707u7wlfC0aA3UHGBJemIOW8HFLF/V8AYj5J44KxmbPvTgYa90X/oxb9g8Jxi758tcyy69+f3ihDx66PhjavyHcjWL+/GXB2LbQz/OfvxWux8zJ4dghkZ8ZsZkojo7EeqMfE4ujxCUiBVTfT4pKXCLSnfpxiUjxKHGJSKG4Q1t9PysqcYlId7rjEpHCUeIq8eLTcNNHs2Mt2cUA7MwO7tvRnFkOsPjR3wRj4cbzeKeM8z72D5nl4U4BycKTwdjwcKwl0up+267ISQOaN74QjC2+5JpgbEikhX/Hc5FYoHzZunA9Rlx6VTC2P/L7MTjyedw9NBDYsSp80M1Lw7EhkYo0Rn4TRoW7o4SvFTnfmLGBQOw3OCcHKjTnfH/RHZeIdOHgesclIkXi6OW8iBSQ3nGJSOEocYlIsQyAQdZmdjjwGHBYuv/33P2adEnt+4G3Ao8Dn3L3N+JnewMIDEhuDrfWtAZaFZc9eiB4TKwN5+EPhGM7I+Old4ZarwaFjxkxIhxriIy/Hhw57upIi2NzoKH1wsi44VjLIUPeGo61hg+cvuap7EMilxodaigrZ8rbIsFAHYdEahJrJh4b+cEQG4AdaY0M/SJEfj/4XaD1sD38byI3B+p8WptDcuzzB+B0d38vMAGYamanADcCt7j7CcArwEX9V00RqSr3fFuNlE1cnuhYR6Uh3Rw4HfheWr4ImN4vNRSRKkuH/OTZaiTPHRdmNsjMNgJ7gAdJ+he+6u4d96W7gJH9U0URqSoH9/ZcW63kejnv7m3ABDM7GvghkPtthJnNAmYBjDo6V54UkVqr857zPcok7v4qsBr4K+BoM+tIfMcAuwPHLHD3JndvGnaEEpdIIRT9HZeZDUvvtDCzPwPOBJ4kSWCfSHebCdTHEsoi0jfuSatinq1G8jwqjgAWmdkgkkS3xN1/bGZbgfvN7GvAL4G7yp7pyD+DieOzY0P2Bw9rCDRBXzwisoz6+MB1ACaGn3Qbd24Lxk5qCfQ1aI2NAA6Hok3kg8MdOsY2Rp7UdwbqGJvfPtb8H23iD3cNGDU58PNsCP+cGRyJtUS6L4yfEo4NCdRxf2QwcqweI2MdbWKdPWK/CKHPOPZGJnC+Q+6LHNMDRe/H5e6bgBMzyp8FTu6PSolILTne1lbrSkSp57yIdKZpbUSkkOp8Whs184lIJw54u+fa8jCzqWb2lJltN7O5GfHDzOy7aXytmR1X7pxKXCLSmacTCebZykgb9W4HPgaMAy4ws3FddrsIeCUdPngLyXDCKCUuEenG29pybTmcDGx392fTSRjuB87tss+5JMMGIRlGeIaZWeyk5lVs9jSzvcCv028bqcgE2X2menSmenRWtHq8w92H9eVCZvZTgtNqdHM48HrJ9wvcfUHJuT4BTHX3i9PvPwVMdPc5Jfs8ke6zK/1+R7pP8O9b1ZfzpR+oma1396ZqXj+L6qF6qB6dufvUalynL/SoKCL9aTdwbMn3WcMD/7hPOoxwCBBepgslLhHpX+uAMWZ2vJkdCpwPLO+yz3KSYYOQDCN82Mu8w6plP64F5XepCtWjM9WjM9WjD9z9gJnNAR4gmSt4obtvMbPrgPXuvpxkuOB3zGw78DJJcouq6st5EZFK0KOiiBSOEpeIFE5NEle5IQBVrMfzZrbZzDaa2foqXnehme1J+690lA01swfN7Jn0z7fUqB7Xmtnu9DPZaGZnVaEex5rZajPbamZbzOxzaXlVP5
NIPar6mZjZ4Wb2CzP7VVqPr6Tlx6dDYranQ2QO7c961DV3r+pG8oJuB/AXwKHAr4Bx1a5HWpfngcYaXPc04CTgiZKyrwNz06/nAjfWqB7XAldW+fMYAZyUfn0U8DTJ8JCqfiaRelT1MwEMODL9ugFYC5wCLAHOT8u/Dcyu5s+pnrZa3HHlGQIwoLn7YyStJ6VKhz1UZdWkQD2qzt2b3X1D+vVrJDPsjqTKn0mkHlXlCa2sFVGLxDUSeKHk+1quEOTAKjN7PF3Uo5aGu3vH9KUvAsNrWJc5ZrYpfZTs90fWUunMACeS3GXU7DPpUg+o8meilbXiDvaX85Pc/SSSkeuXmdlpta4QJP/jkiTVWpgPjCZZ/LcZuLlaFzazI4HvA59399+Wxqr5mWTUo+qfibu3ufsEkp7mJ9ODlbUOBrVIXHmGAFSFu+9O/9xDsuxaLaeifsnMRgCkf+6pRSXc/aX0H007cAdV+kzMrIEkWdzr7j9Ii6v+mWTVo1afSXrtHq+sdTCoReLKMwSg35nZEWZ2VMfXwBTgifhR/ap02EPNVk3qSBSpj1OFzySdwuQu4El3/2ZJqKqfSage1f5MTCtrlVeLFgHgLJIWmx3Al2tUh78gadH8FbClmvUA7iN55GgleVdxEfBW4CHgGeBnwNAa1eM7wGZgE0niGFGFekwieQzcBGxMt7Oq/ZlE6lHVzwR4D8nKWZtIkuS8kt/ZXwDbgaXAYdX6na23TUN+RKRwDvaX8yJSQEpcIlI4SlwiUjhKXCJSOEpcIlI4SlwiUjhKXCJSOP8fV4asJwmR+wEAAAAASUVORK5CYII=\n",
            "text/plain": [
              "<Figure size 432x288 with 2 Axes>"
            ]
          },
          "metadata": {
            "tags": [],
            "needs_background": "light"
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "z8JIisIsSeOO"
      },
      "source": [
        "Finally, we define a `torch.utils.data.TensorDataset` for the train and test datasets."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "RZ3gwA1jmAJt"
      },
      "source": [
        "trainset = TensorDataset(torch.Tensor(x_train), torch.Tensor(y_train))\n",
        "testset = TensorDataset(torch.Tensor(x_test), torch.Tensor(y_test))"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ui5dKavnSoQ9"
      },
      "source": [
        "## Model and losses\n",
        "\n",
        "We provide the code for all the models used in the competition in the `neurips_bdl_starter_kit/torch_models.py` module. Here, we will load a ResNet-20 model with filter response normalization (FRN) and swish activations. The models are implemented in `pytorch`. \n",
        "\n",
        "We also define the cross-entropy likelihood (`log_likelihood_fn`) and Gaussian prior (`log_prior_fn`), and the corresponding posterior log-density (`log_posterior_fn`). The `log_posterior_wgrad_fn` computes the posterior log-density and its gradients with respect to the parameters of the model.\n",
        "\n",
        "The `evaluate_fn` function computes the accuracy and predictions of the model on a given dataset; we will use this function to generate the predictions for our submission."
      ]
    },
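    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what these functions compute (the symbols $w$, $d$, $\\sigma^2$, $B$ and $f_w$ are notation introduced here for illustration, not names from the starter kit): for network weights $w \\in \\mathbb{R}^d$, a mini-batch $B$ of labelled examples, and prior variance $\\sigma^2$ (`prior_variance` in the code),\n",
        "\n",
        "$$\\log p(B \\mid w) = \\sum_{(x_i, y_i) \\in B} \\log \\mathrm{softmax}\\big(f_w(x_i)\\big)_{y_i},$$\n",
        "\n",
        "$$\\log p(w) = -\\frac{\\lVert w \\rVert^2}{2\\sigma^2} - \\frac{d}{2} \\log\\big(2 \\pi \\sigma^2\\big),$$\n",
        "\n",
        "$$\\log p(w \\mid B) = \\log p(B \\mid w) + \\log p(w) + \\mathrm{const}.$$\n",
        "\n",
        "`log_posterior_fn` returns the sum of the first two quantities, that is, the log-posterior up to the additive constant. The likelihood term is summed over the current mini-batch only and is not rescaled to the size of the full training set."
      ]
    },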
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "gwMaZjf9yDTE",
        "outputId": "19c0f9ef-bb73-4427-ab61-298baf1a7651"
      },
      "source": [
        "!nvidia-smi"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Thu Jul 15 16:35:54 2021       \n",
            "+-----------------------------------------------------------------------------+\n",
            "| NVIDIA-SMI 470.42.01    Driver Version: 460.32.03    CUDA Version: 11.2     |\n",
            "|-------------------------------+----------------------+----------------------+\n",
            "| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |\n",
            "| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |\n",
            "|                               |                      |               MIG M. |\n",
            "|===============================+======================+======================|\n",
            "|   0  Tesla T4            Off  | 00000000:00:04.0 Off |                    0 |\n",
            "| N/A   43C    P8    10W /  70W |      0MiB / 15109MiB |      0%      Default |\n",
            "|                               |                      |                  N/A |\n",
            "+-------------------------------+----------------------+----------------------+\n",
            "                                                                               \n",
            "+-----------------------------------------------------------------------------+\n",
            "| Processes:                                                                  |\n",
            "|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |\n",
            "|        ID   ID                                                   Usage      |\n",
            "|=============================================================================|\n",
            "|  No running processes found                                                 |\n",
            "+-----------------------------------------------------------------------------+\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "1kbD51slyAU-",
        "outputId": "f60f133b-57f2-45e0-accd-478aabcee6d6"
      },
      "source": [
        "net_fn =  p_models.get_model(\"resnet20_frn_swish\", data_info={\"num_classes\": 10})\n",
        "if torch.cuda.is_available():\n",
        "    print(\"GPU available!\")\n",
        "    net_fn = net_fn.cuda()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "GPU available!\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "EI5zfRQW0VN6"
      },
      "source": [
        "prior_variance = 5.\n",
        "\n",
        "def log_likelihood_fn(model_state_dict, batch):\n",
        "    \"\"\"Computes the log-likelihood.\"\"\"\n",
        "    x, y = batch\n",
        "    if torch.cuda.is_available():\n",
        "        x = x.cuda()\n",
        "        y = y.cuda()\n",
        "    net_fn.zero_grad()\n",
        "    for name, param in net_fn.named_parameters():\n",
        "        param.data = model_state_dict[name]\n",
        "    logits = net_fn(x)\n",
        "    num_classes = logits.shape[-1]\n",
        "    labels = F.one_hot(y.to(torch.int64), num_classes=num_classes)\n",
        "    # Sum of log p(y_i | x_i) over the batch; dim=-1 normalizes over classes.\n",
        "    softmax_xent = torch.sum(labels * F.log_softmax(logits, dim=-1))\n",
        "\n",
        "    return softmax_xent\n",
        "\n",
        "\n",
        "def log_prior_fn(model_state_dict):\n",
        "    \"\"\"Computes the Gaussian prior log-density.\"\"\"\n",
        "    n_params = sum(p.numel() for p in model_state_dict.values()) \n",
        "    exp_term = sum((-p**2 / (2 * prior_variance)).sum() for p in model_state_dict.values() )\n",
        "    norm_constant = -0.5 * n_params * math.log((2 * math.pi * prior_variance))\n",
        "    return exp_term + norm_constant\n",
        "\n",
        "\n",
        "def log_posterior_fn(model_state_dict, batch):\n",
        "    log_lik = log_likelihood_fn(model_state_dict, batch)\n",
        "    log_prior = log_prior_fn(model_state_dict)\n",
        "    return log_lik + log_prior\n",
        "\n",
        "\n",
        "def get_accuracy_fn(batch, model_state_dict):\n",
        "    x, y = batch\n",
        "    if torch.cuda.is_available():\n",
        "      x = x.cuda()\n",
        "      y = y.cuda()\n",
        "    # get logits \n",
        "    net_fn.eval()\n",
        "    with torch.no_grad():\n",
        "      for name, param in net_fn.named_parameters():\n",
        "          param.data = model_state_dict[name]\n",
        "      logits = net_fn(x)\n",
        "    net_fn.train()\n",
        "    # get log probs \n",
        "    log_probs = F.log_softmax(logits, dim=1)\n",
        "    # get preds \n",
        "    probs = torch.exp(log_probs)\n",
        "    preds = torch.argmax(logits, dim=1)\n",
        "    accuracy = (preds == y).float().mean()\n",
        "    return accuracy, probs\n",
        "\n",
        "\n",
        "def evaluate_fn(data_loader, model_state_dict):\n",
        "    sum_accuracy = 0\n",
        "    all_probs = []\n",
        "    for x, y in data_loader:       \n",
        "        batch_accuracy, batch_probs = get_accuracy_fn((x, y), model_state_dict)\n",
        "        sum_accuracy += batch_accuracy.item()\n",
        "        all_probs.append(batch_probs)\n",
        "    all_probs = torch.cat(all_probs, dim=0)\n",
        "    return sum_accuracy / len(data_loader), all_probs"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "KEZnIWNmUEwE"
      },
      "source": [
        "## Optimization and training \n",
        "\n",
        "In this colab we train an approximate maximum-a-posteriori (MAP) solution as our submission for simplicity. You can find efficient implementations of more advanced baselines in jax [here](https://github.com/google-research/google-research/tree/master/bnn_hmc).\n",
        "\n",
        "We use SGD with momentum. You can adjust the hyper-parameters or switch to a different optimizer by changing the code below.\n",
        "\n",
        "We run training for 5 epochs, which can take several minutes to complete. Note that in order to achieve good results you need to run the method substantially longer and tune the hyper-parameters."
      ]
    },
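    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For reference, here is a sketch of the update that `torch.optim.SGD` applies with momentum, assuming its default settings for everything except `lr` and `momentum` (no dampening, no Nesterov momentum, no weight decay). The symbols $v_t$, $g_t$, $\\mu$ and $\\eta$ are notation introduced here: $g_t$ is the mini-batch gradient of the negative log-posterior, $\\mu$ corresponds to `momentum_decay`, and $\\eta$ to `lr`.\n",
        "\n",
        "$$v_{t+1} = \\mu \\, v_t + g_{t+1}, \\qquad w_{t+1} = w_t - \\eta \\, v_{t+1}.$$\n",
        "\n",
        "With this convention the learning rate scales the whole velocity, which differs from formulations that apply $\\eta$ to the gradient before accumulating it. Keep this in mind when transferring hyper-parameters from other frameworks."
      ]
    },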
    {
      "cell_type": "code",
      "metadata": {
        "id": "eKxlP6Iu1jmH"
      },
      "source": [
        "batch_size = 100\n",
        "test_batch_size = 100\n",
        "num_epochs = 10\n",
        "momentum_decay = 0.9\n",
        "lr = 0.001\n",
        "\n",
        "train_loader = DataLoader(trainset, batch_size=batch_size, shuffle=True)\n",
        "test_loader = DataLoader(testset, batch_size=test_batch_size, shuffle=False)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "fLGw_AVJtjy3",
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "outputId": "ec69ec42-abe0-41a1-b290-dfd260ade08d"
      },
      "source": [
        "epoch_steps = len(train_loader)\n",
        "\n",
        "optimizer = optim.SGD(net_fn.parameters(), lr=lr, momentum=momentum_decay)\n",
        "\n",
        "for epoch in range(num_epochs):\n",
        "  running_loss = 0.0\n",
        "  total_loss = 0.0\n",
        "  for i, data in enumerate(train_loader):\n",
        "    optimizer.zero_grad()\n",
        "    model_state_dict = copy.deepcopy(net_fn.state_dict())\n",
        "    loss = - log_posterior_fn(model_state_dict, data)\n",
        "    loss.backward()\n",
        "    optimizer.step()\n",
        "    running_loss += loss.item()\n",
        "    total_loss += loss.item()\n",
        "    if i % 100 == 99:    # print every 100 mini-batches\n",
        "      print('[%d, %5d] loss: %.3f' %\n",
        "            (epoch + 1, i + 1, running_loss / 100))\n",
        "      running_loss = 0.0\n",
        "  model_state_dict = copy.deepcopy(net_fn.state_dict())\n",
        "  test_acc, all_test_probs = evaluate_fn(test_loader, model_state_dict)\n",
        "  print(\"Epoch {}\".format(epoch))\n",
        "  print(\"\\tAverage loss: {}\".format(total_loss / epoch_steps))\n",
        "  print(\"\\tTest accuracy: {}\".format(test_acc))"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "/usr/local/lib/python3.7/dist-packages/ipykernel_launcher.py:16: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.\n",
            "  app.launch_new_instance()\n"
          ],
          "name": "stderr"
        },
        {
          "output_type": "stream",
          "text": [
            "[1,   100] loss: 484943.916\n",
            "[1,   200] loss: 484928.159\n",
            "[1,   300] loss: 484923.939\n",
            "[1,   400] loss: 484918.999\n",
            "[1,   500] loss: 484916.494\n",
            "Epoch 0\n",
            "\tAverage loss: 484926.3013125\n",
            "\tTest accuracy: 0.2085999943315983\n",
            "[2,   100] loss: 484915.014\n",
            "[2,   200] loss: 484910.846\n",
            "[2,   300] loss: 484910.031\n",
            "[2,   400] loss: 484910.391\n",
            "[2,   500] loss: 484906.492\n",
            "Epoch 1\n",
            "\tAverage loss: 484910.55475\n",
            "\tTest accuracy: 0.29139999285340307\n",
            "[3,   100] loss: 484909.698\n",
            "[3,   200] loss: 484906.397\n",
            "[3,   300] loss: 484905.250\n",
            "[3,   400] loss: 484906.953\n",
            "[3,   500] loss: 484908.406\n",
            "Epoch 2\n",
            "\tAverage loss: 484907.340625\n",
            "\tTest accuracy: 0.3525999927520752\n",
            "[4,   100] loss: 484907.397\n",
            "[4,   200] loss: 484911.037\n",
            "[4,   300] loss: 484912.406\n",
            "[4,   400] loss: 484915.801\n",
            "[4,   500] loss: 484919.463\n",
            "Epoch 3\n",
            "\tAverage loss: 484913.22075\n",
            "\tTest accuracy: 0.3986999899148941\n",
            "[5,   100] loss: 484916.747\n",
            "[5,   200] loss: 484920.216\n",
            "[5,   300] loss: 484925.258\n",
            "[5,   400] loss: 484931.242\n",
            "[5,   500] loss: 484930.337\n",
            "Epoch 4\n",
            "\tAverage loss: 484924.759875\n",
            "\tTest accuracy: 0.4132999897003174\n",
            "[6,   100] loss: 484933.309\n",
            "[6,   200] loss: 484936.083\n",
            "[6,   300] loss: 484938.059\n",
            "[6,   400] loss: 484946.466\n",
            "[6,   500] loss: 484947.965\n",
            "Epoch 5\n",
            "\tAverage loss: 484940.3765\n",
            "\tTest accuracy: 0.43829998791217806\n",
            "[7,   100] loss: 484949.258\n",
            "[7,   200] loss: 484954.157\n",
            "[7,   300] loss: 484956.161\n",
            "[7,   400] loss: 484961.198\n",
            "[7,   500] loss: 484963.157\n",
            "Epoch 6\n",
            "\tAverage loss: 484956.7863125\n",
            "\tTest accuracy: 0.4489999884366989\n",
            "[8,   100] loss: 484964.502\n",
            "[8,   200] loss: 484971.585\n",
            "[8,   300] loss: 484974.297\n",
            "[8,   400] loss: 484978.401\n",
            "[8,   500] loss: 484985.867\n",
            "Epoch 7\n",
            "\tAverage loss: 484974.9303125\n",
            "\tTest accuracy: 0.47729998916387556\n",
            "[9,   100] loss: 484985.074\n",
            "[9,   200] loss: 484990.870\n",
            "[9,   300] loss: 484994.815\n",
            "[9,   400] loss: 485000.512\n",
            "[9,   500] loss: 485005.290\n",
            "Epoch 8\n",
            "\tAverage loss: 484995.312125\n",
            "\tTest accuracy: 0.4798999884724617\n",
            "[10,   100] loss: 485006.964\n",
            "[10,   200] loss: 485013.146\n",
            "[10,   300] loss: 485019.063\n",
            "[10,   400] loss: 485021.846\n",
            "[10,   500] loss: 485027.199\n",
            "Epoch 9\n",
            "\tAverage loss: 485017.6435\n",
            "\tTest accuracy: 0.4944999879598618\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "L9KjyHdz7rSw"
      },
      "source": [
        "all_test_probs = np.asarray(all_test_probs.cpu())"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "QKDfvlHzxulX"
      },
      "source": [
        "## Evaluating metrics"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "EiQcDfma00Qn"
      },
      "source": [
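        "The starter kit includes a `metrics` module that computes the two metrics used in the competition: agreement (the fraction of test inputs on which your top-1 prediction matches that of the HMC reference) and total variation distance between your predictive distributions and the reference ones, averaged over test inputs."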
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "XNr_UlytxxB8"
      },
      "source": [
        "import metrics"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "zu-gQzoY05P2"
      },
      "source": [
        "We can load the HMC reference predictions from the starter kit as well."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "gdIagtdhxyRu"
      },
      "source": [
        "with open('neurips_bdl_starter_kit/data/cifar10/probs.csv', 'r') as fp:\n",
        "  reference = np.loadtxt(fp)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "80T57WFH09Eh"
      },
      "source": [
        "Now we can compute the metrics!"
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TCSC3Svbx-0M",
        "outputId": "d0c25a26-2ec0-4b5c-8b26-d64f8b7f7653"
      },
      "source": [
        "metrics.agreement(all_test_probs, reference)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.5575"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 15
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "fJOvBOe0zmhT",
        "outputId": "f650c574-63d4-4e42-c7e2-73316f0f9ce2"
      },
      "source": [
        "metrics.total_variation_distance(all_test_probs, reference)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "0.5070471073085077"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 17
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Y7nYa4w3VmnO"
      },
      "source": [
        "## Preparing the submission\n",
        "\n",
        "Once you have run the code above, `all_test_probs` should contain an array of shape `(10000, 10)`, where rows correspond to test datapoints and columns correspond to classes."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "atqsSbjBInpg",
        "outputId": "e09d6459-0b67-4d1a-8b73-ac0be9a4acc2"
      },
      "source": [
        "all_test_probs.shape"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "execute_result",
          "data": {
            "text/plain": [
              "(10000, 10)"
            ]
          },
          "metadata": {
            "tags": []
          },
          "execution_count": 13
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FB2LhtfUW1ar"
      },
      "source": [
        "Now, we need to save the array as `cifar10_probs.csv` and create a zip archive with this file."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "h0CtqrJCI2-F",
        "outputId": "50e9895a-e345-4dc9-f899-c1504b35ff04"
      },
      "source": [
        "np.savetxt(\"cifar10_probs.csv\", all_test_probs)\n",
        "\n",
        "!zip submission.zip cifar10_probs.csv"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "  adding: cifar10_probs.csv (deflated 63%)\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f-3deTsUNcqI"
      },
      "source": [
        "Finally, you can download the submission by running the code below. If the download doesn't start, check that your browser has not blocked it automatically."
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 17
        },
        "id": "ieCUqCPjJP0k",
        "outputId": "918619c7-c4f1-4bab-eb49-0ff7da7b1683"
      },
      "source": [
        "from google.colab import files\n",
        "files.download('submission.zip')"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "\n",
              "    async function download(id, filename, size) {\n",
              "      if (!google.colab.kernel.accessAllowed) {\n",
              "        return;\n",
              "      }\n",
              "      const div = document.createElement('div');\n",
              "      const label = document.createElement('label');\n",
              "      label.textContent = `Downloading \"${filename}\": `;\n",
              "      div.appendChild(label);\n",
              "      const progress = document.createElement('progress');\n",
              "      progress.max = size;\n",
              "      div.appendChild(progress);\n",
              "      document.body.appendChild(div);\n",
              "\n",
              "      const buffers = [];\n",
              "      let downloaded = 0;\n",
              "\n",
              "      const channel = await google.colab.kernel.comms.open(id);\n",
              "      // Send a message to notify the kernel that we're ready.\n",
              "      channel.send({})\n",
              "\n",
              "      for await (const message of channel.messages) {\n",
              "        // Send a message to notify the kernel that we're ready.\n",
              "        channel.send({})\n",
              "        if (message.buffers) {\n",
              "          for (const buffer of message.buffers) {\n",
              "            buffers.push(buffer);\n",
              "            downloaded += buffer.byteLength;\n",
              "            progress.value = downloaded;\n",
              "          }\n",
              "        }\n",
              "      }\n",
              "      const blob = new Blob(buffers, {type: 'application/binary'});\n",
              "      const a = document.createElement('a');\n",
              "      a.href = window.URL.createObjectURL(blob);\n",
              "      a.download = filename;\n",
              "      div.appendChild(a);\n",
              "      a.click();\n",
              "      div.remove();\n",
              "    }\n",
              "  "
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        },
        {
          "output_type": "display_data",
          "data": {
            "application/javascript": [
              "download(\"download_3c7d6197-cb10-4c80-8600-0c2e093d82fe\", \"submission.zip\", 949458)"
            ],
            "text/plain": [
              "<IPython.core.display.Javascript object>"
            ]
          },
          "metadata": {
            "tags": []
          }
        }
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "izLlUVM4XIWv"
      },
      "source": [
        "Now you can head over to the [submission system](https://competitions.codalab.org/competitions/33512?secret_key=10f23c1f-9c86-4a7a-8406-d85b0a0713f2#participate) and upload your submission. Good luck :)"
      ]
    }
  ]
}
